'''
Author: SlytherinGe
LastEditTime: 2021-10-22 17:49:55
'''
import os

from cv2 import kmeans
from sklearn.cluster import KMeans
from sklearn.feature_extraction import image
import mmcv
import numpy as np
import matplotlib.pyplot as plt
# Directory holding the raw .jpg images that are shown next to the clusters.
IMG_ROOT: str = '/media/gejunyao/Disk1/Datasets/CAS-OpenSARShip/ship_detection_online/VOC2012/JPEGImages/'
# File stem of the image to inspect; also used to find its cached DFL vectors.
IMG_ID: str = 'Gao_ship_hh_020180314670301805'
# Number of k-means clusters the per-pixel DFL vectors are grouped into.
N_CLUSTER: int = 16
def load_DFL_vec_to_numpy(img_id, cache_root='/home/gejunyao/fast_cache/DFL_Cache/'):
    """Load the cached DFL vectors for one image as a channels-last numpy array.

    The cache file ``<cache_root>/<img_id>.pkl`` is read with ``mmcv.load``.
    The loaded object is used through torch-tensor methods (``permute``,
    ``detach``, ``cpu``, ``numpy``). It is assumed to be laid out as
    (vec, h, w); TODO confirm against the code that writes the cache.

    Args:
        img_id (str): Image identifier, used as the cache file stem.
        cache_root (str): Directory containing the ``.pkl`` DFL cache files.
            Defaults to the original hard-coded location.

    Returns:
        np.ndarray: The DFL vectors transposed from (vec, h, w) to (h, w, vec).
    """
    # os.path.join tolerates cache_root with or without a trailing separator.
    DFL_vec = mmcv.load(os.path.join(cache_root, img_id + '.pkl'))
    # Move from (vec, h, w) to (h, w, vec) so every pixel holds one feature vector.
    # .detach().cpu() makes .numpy() safe for tensors that were saved on the GPU
    # or with autograd history; both are no-ops for a plain CPU tensor.
    return DFL_vec.permute(1, 2, 0).detach().cpu().numpy()

if __name__ == '__main__':

    DFL_vec = load_DFL_vec_to_numpy(IMG_ID)
    h, w, vec_len = DFL_vec.shape

    # Cluster each spatial position's DFL vector independently. Flatten the
    # (h, w, vec) map to (h*w, vec) samples, run k-means, then fold the
    # per-sample labels back into an (h, w) label map.
    flatten_DFL_vec = DFL_vec.reshape(-1, vec_len)
    # n_init is pinned to the classic default (10) so results do not shift with
    # scikit-learn's change of default to n_init='auto'. random_state makes
    # reruns of this script reproducible.
    km = KMeans(n_clusters=N_CLUSTER, n_init=10, random_state=0).fit(flatten_DFL_vec)
    clustered_DFL = km.labels_.reshape(h, w)

    # Read the raw image and bring it onto the DFL grid. First rescale it to
    # fit inside (h, w) while keeping its aspect ratio, then pad the remainder
    # so it lines up with the label map.
    img = mmcv.imread(os.path.join(IMG_ROOT, IMG_ID + '.jpg'))
    rescale_img = mmcv.imrescale(img, (h, w))
    padded_img = mmcv.impad(rescale_img, shape=(h, w))
    # mmcv.imread returns BGR by default, but matplotlib's imshow expects RGB.
    padded_img = mmcv.bgr2rgb(padded_img)

    # Visualization: raw image | cluster label map | label map blended over the image.
    # A qualitative colormap (tab20 has 20 distinct colours) with nearest-neighbour
    # interpolation keeps each cluster a separate, unblurred colour.
    label_style = dict(cmap='tab20', interpolation='nearest')
    plt.figure()
    plt.suptitle('cluster result for {}, n_clusters={}'.format(IMG_ID, N_CLUSTER))
    plt.subplot(1, 3, 1)
    plt.imshow(padded_img)
    plt.title('raw image')
    plt.subplot(1, 3, 2)
    plt.imshow(clustered_DFL, **label_style)
    plt.title('clustered DFL')
    plt.subplot(1, 3, 3)
    plt.imshow(padded_img)
    plt.imshow(clustered_DFL, alpha=0.3, **label_style)
    plt.title('fused image')
    plt.show()

